import os
import re
import json
import time
import numpy as np
from datetime import datetime
from tqdm import tqdm
from call_gpt import call_gpt
from prompts.time_series_prediction import standard_prompt, llm_ar_prompt
import argparse
import tiktoken

parser = argparse.ArgumentParser(description='Run time series prediction task')

parser.add_argument('--dataset_dir', type=str, default='../dataset', help='Directory of the dataset')
parser.add_argument('--model', type=str, default='gpt-3.5-turbo', help='Model to use')
parser.add_argument('--method', type=str, default='standard', choices=['sma', 'standard', 'llm-ar', 'llm-ar-sc', 'llmtime', 'mip'], help='Method to use')
parser.add_argument('--steps', type=int, default=10, help='Number of values to predict')
parser.add_argument('--max_input_length', type=int, default=100, help='Maximum number of most recent input values given to the predictor')
parser.add_argument('--sma_window_size', type=int, default=20, help='Window size of average')
parser.add_argument('--log_dir', type=str, default='log', help='Directory for logs')
parser.add_argument('--k_samples', type=int, default=5, help='Number of samples per step for the llm-ar-sc method')

args = parser.parse_args()
dataset_dir = args.dataset_dir
task = 'time_series_prediction'
model = args.model
method = args.method.lower()
steps = args.steps
max_input_length = args.max_input_length
sma_window_size = args.sma_window_size
k_samples = args.k_samples

# 'llmtime' and 'mip' are accepted by --method but the evaluation loop has no
# implementation for them; fail fast (argparse exits with a usage message)
# instead of crashing after the log file has been created.
if method in ('llmtime', 'mip'):
    parser.error(f"method '{method}' is not implemented yet")

log_dir_base = os.path.join(args.log_dir, task)

current_time_str = datetime.now().strftime('%Y%m%d_%H%M%S')

# exist_ok avoids the check-then-create race of a separate os.path.exists().
os.makedirs(log_dir_base, exist_ok=True)

# Per-sample CSV results. The handle stays open for the rest of the script;
# the header is flushed immediately so it is on disk even if a run aborts.
log_filename = os.path.join(log_dir_base, f'{model}_{method}_{current_time_str}.csv')
log_file = open(log_filename, 'w', encoding='utf8')
log_file.write('predict_id,next_values,mae,mape\n')
log_file.flush()


# Tokenizer used only to estimate per-step token budgets for the LLM calls.
try:
    encoding = tiktoken.encoding_for_model(model)
except KeyError:
    # tiktoken has no mapping for this model name; fall back to cl100k_base
    # (the gpt-3.5-turbo / gpt-4 encoding) as an approximation for budgeting.
    encoding = tiktoken.get_encoding('cl100k_base')


def get_avg_tokens_per_step(input_str, time_sep=',', encoder=None):
    """
    Return the average number of tokens spent per time step of a serialized series.

    The token count covers the whole string, separators included, so
    ``round(n * result)`` approximates the ``max_tokens`` budget for ``n`` steps.

    :param input_str: Serialized series such as ``"1.0,2.5,3.1,"``. Blank fields
        (e.g. after the trailing separator the callers append) are not steps.
    :param time_sep: Separator between consecutive values.
    :param encoder: Object with an ``encode(str)`` method returning a sequence
        of tokens; defaults to the module-level tiktoken ``encoding``.
    :return: Tokens per step as a float; 0.0 if ``input_str`` holds no values.
    """
    if encoder is None:
        encoder = encoding
    # Count only non-blank fields: "1,2,3," splits into 4 parts but has 3 steps.
    num_steps = sum(1 for part in input_str.split(time_sep) if part.strip())
    if num_steps == 0:
        return 0.0
    return len(encoder.encode(input_str)) / num_steps


def calculate_mae(actual, predicted):
    """
    Compute the Mean Absolute Error (MAE) of a forecast.

    :param actual: Observed values.
    :param predicted: Forecast values, aligned element-wise with ``actual``.
    :return: Mean of ``|actual[i] - predicted[i]|`` as a float.
    :raises ValueError: If the two sequences differ in length.
    """
    n = len(actual)
    if n != len(predicted):
        raise ValueError("The length of actual and predicted lists must be the same.")

    total_error = 0
    for observed, forecast in zip(actual, predicted):
        total_error += abs(observed - forecast)
    return total_error / n


def calculate_mape(actual, predicted):
    """
    Calculate Mean Absolute Percentage Error (MAPE) between two lists of floats.

    The percentage error is undefined where the actual value is 0, so those
    points are excluded from both the sum and the count: the mean is taken
    over the remaining points only (dividing by the full length would bias
    MAPE low whenever zeros are present).

    :param actual: List of actual values.
    :param predicted: List of predicted values.
    :return: MAPE in percent as a float; ``nan`` if no actual value is
        non-zero (including empty input), since the metric is then undefined.
    :raises ValueError: If the two lists differ in length.
    """
    if len(actual) != len(predicted):
        raise ValueError("The length of actual and predicted lists must be the same.")

    errors = [abs((a - p) / a) for a, p in zip(actual, predicted) if a != 0]
    if not errors:
        return float('nan')
    mape = sum(errors) / len(errors)
    return mape * 100


# Task samples: each entry provides an 'input' history and an 'output' target.
metadata_path = os.path.join(dataset_dir, task, 'task.json')
with open(metadata_path, 'r', encoding='utf8') as meta_file:
    metadata = json.load(meta_file)

# Per-sample metrics collected by the evaluation loop.
mae_list, mape_list, std_list = [], [], []

# Sampling temperature used for every LLM call.
temperature = 0.7

# Directory holding per-sample prompt/answer transcripts of the LLM methods.
run_log_dir = os.path.join(log_dir_base, f'{model}_{method}_{current_time_str}')
if method in ('standard', 'llm-ar', 'llm-ar-sc'):
    os.makedirs(run_log_dir, exist_ok=True)


def _sma_forecast(history, horizon, window_size):
    """Forecast `horizon` values recursively with a simple moving average.

    Each forecast is the mean of the last `window_size` values, earlier
    forecasts included. Forecasts are collected separately from the sliding
    window so all `horizon` of them are returned whatever `window_size` is.
    """
    window = list(history)
    forecast = []
    for _ in range(horizon):
        next_value = np.mean(window[-window_size:])
        forecast.append(next_value)
        window.append(next_value)
    return forecast


def _parse_values(text, time_sep=','):
    """Parse every numeric field of a separator-delimited model answer.

    Fields that are not valid floats (blank, prose, fragments) are skipped.
    """
    values = []
    for field in text.split(time_sep):
        try:
            values.append(float(field.strip()))
        except ValueError:
            continue
    return values


def _parse_first_value(text, time_sep=','):
    """Parse the first non-blank field of a model answer.

    Returns ``(field, value)`` (stripped field text and its float value), or
    None if that field is not a number. Only the first field is used because,
    with a budget of ~1.4 steps of tokens, the model may begin a second value
    that gets truncated; that fragment must not be merged into the first.
    """
    for field in text.split(time_sep):
        field = field.strip()
        if not field:
            continue
        try:
            return field, float(field)
        except ValueError:
            return None
    return None


for sid, item in enumerate(tqdm(metadata)):
    # Keep only the most recent `max_input_length` observations as context.
    input_sequence = item['input'][-max_input_length:]

    if method == 'sma':
        pred_sequence = _sma_forecast(input_sequence, steps, sma_window_size)
    elif method == 'standard':
        # One-shot: ask the model to continue the whole series in one call.
        input_str = ','.join(map(str, input_sequence)) + ','
        prompt = standard_prompt.format(sequence=input_str)
        avg_tokens = get_avg_tokens_per_step(input_str)
        # Budget tokens for 40% more steps than needed so the answer is not cut
        # off early; surplus values are trimmed in post-processing. This must be
        # a local: overwriting the global `steps` would grow the evaluation
        # horizon on every sample.
        gen_steps = int(steps * 1.4)
        max_tokens = round(gen_steps * avg_tokens)

        result = call_gpt(
            prompt,
            model,
            temperature=temperature,
            max_tokens=max_tokens,
        )

        with open(os.path.join(run_log_dir, f'{sid}.txt'), 'w', encoding='utf8') as f:
            f.write(prompt + '\n\nAnswer:\n')
            f.write(result)

        pred_sequence = _parse_values(result)
        time.sleep(3)
    elif method in ('llm-ar', 'llm-ar-sc'):
        # Autoregressive: request one value per call and feed it back as input.
        pred_sequence = []
        input_str = ','.join(map(str, input_sequence)) + ','
        for _ in range(steps):
            avg_tokens = get_avg_tokens_per_step(input_str)
            max_tokens = round(1.4 * avg_tokens)
            prompt = llm_ar_prompt.format(sequence=input_str)
            if method == 'llm-ar':
                result = call_gpt(
                    prompt,
                    model,
                    temperature=temperature,
                    max_tokens=max_tokens,
                )
                with open(os.path.join(run_log_dir, f'{sid}.txt'), 'a+', encoding='utf8') as f:
                    f.write(prompt + '\nAnswer:\n')
                    f.write(result + '\n\n\n')
                print('max tokens:', max_tokens)
                print(result)
                parsed = _parse_first_value(result)
                if parsed is None:
                    # Unparseable answer: stop; post-processing pads the rest.
                    break
                value_str, next_value = parsed
                pred_sequence.append(next_value)
                input_str += value_str + ','
            else:
                # Self-consistency: draw k_samples answers and use their mean.
                sampled_values = []
                for _ in range(k_samples):
                    result = call_gpt(
                        prompt,
                        model,
                        temperature=temperature,
                        max_tokens=max_tokens,
                    )
                    print(result)
                    parsed = _parse_first_value(result)
                    if parsed is None:
                        continue
                    sampled_values.append(parsed[1])
                    time.sleep(3)
                if not sampled_values:
                    # No sample parsed: stop; post-processing pads the rest.
                    break
                next_value = float(np.mean(sampled_values))
                print('average value:')
                print(next_value)
                pred_sequence.append(next_value)
                input_str += str(next_value) + ','
                std_list.append(np.std(sampled_values))
            time.sleep(3)
    else:
        raise NotImplementedError(f"Method '{method}' is not implemented")

    # Post-processing: make the forecast exactly `steps` values long.
    if not pred_sequence:
        # Nothing usable: fall back to repeating the last observed value.
        pred_sequence = [input_sequence[-1]] * steps
    elif len(pred_sequence) < steps:
        # Too short: pad by repeating the last predicted value.
        pred_sequence = pred_sequence + [pred_sequence[-1]] * (steps - len(pred_sequence))
    else:
        pred_sequence = pred_sequence[:steps]

    # Score against the first `steps` ground-truth values (fewer if unavailable).
    ground_truth = item['output'][:steps]
    pred_sequence = pred_sequence[:len(ground_truth)]
    mae = calculate_mae(ground_truth, pred_sequence)
    mape = calculate_mape(ground_truth, pred_sequence)
    mae_list.append(mae)
    mape_list.append(mape)

    log_file.write(f'{sid},{" ".join(map(str, pred_sequence))},{mae},{mape}\n')
    log_file.flush()


# Calculate and display the average MAE and MAPE.
# Every per-sample row has been written by now; release the CSV log handle.
log_file.close()

# Guard against an empty dataset instead of dividing by zero.
avg_mae = sum(mae_list) / len(mae_list) if mae_list else float('nan')
# Ignore NaN MAPE entries, if any (e.g. a sample whose MAPE is undefined), so a
# single such sample does not turn the whole average into NaN.
valid_mapes = [m for m in mape_list if not np.isnan(m)]
avg_mape = sum(valid_mapes) / len(valid_mapes) if valid_mapes else float('nan')

# Print the performance report
print(f"Performance Report - Model: {model}, Method: {method}")
print(f"Total Samples: {len(mae_list)}")
print(f"Average MAE: {avg_mae}")
print(f"Average MAPE: {avg_mape}%")
if len(valid_mapes) < len(mape_list):
    print(f"(MAPE was NaN for {len(mape_list) - len(valid_mapes)} sample(s); excluded from the average)")

if method == 'llm-ar-sc':
    # std_list holds the spread of the sampled values at each predicted step.
    avg_std = sum(std_list) / len(std_list) if std_list else float('nan')
    print(f"Average standard deviation: {avg_std}")
